/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module

prelude
public import Lean.Compiler.LCNF.InferType

public section

namespace Lean.Compiler.LCNF

/-- Helper class for lifting `CompilerM.codeBind` -/
class MonadCodeBind (m : Type → Type) where
  codeBind : (c : Code) → (f : FVarId → m Code) → m Code

/--
Return code that is equivalent to `c >>= f`. That is, executes `c`, and then `f x`, where
`x` is a variable that contains the result of `c`'s computation.

If `c` contains a jump to a join point `jp_i` not declared in `c`, we throw an exception because
an invalid block would be generated. It would be invalid because `f` would not
be applied to `jp_i`. Note that, we could have decided to create a copy of `jp_i` where we apply `f` to it,
by we decided to not do it to avoid code duplication.
-/
abbrev Code.bind [MonadCodeBind m] (c : Code) (f : FVarId → m Code) : m Code :=
  MonadCodeBind.codeBind c f

partial def CompilerM.codeBind (c : Code) (f : FVarId → CompilerM Code) : CompilerM Code := do
  go c |>.run {}
where
  go (c : Code) : ReaderT FVarIdSet CompilerM Code := do
    match c with
    | .let decl k => return .let decl (← go k)
    | .fun decl k => return .fun decl (← go k)
    | .jp decl k =>
      let value ← go decl.value
      let type ← value.inferParamType decl.params
      let decl ← decl.update' type value
      withReader (fun s => s.insert decl.fvarId) do
        return .jp decl (← go k)
    | .cases c =>
      let alts ← c.alts.mapM fun
        | .alt ctorName params k => return .alt ctorName params (← go k)
        | .default k => return .default (← go k)
      if alts.isEmpty then
        throwError "`Code.bind` failed, empty `cases` found"
      let resultType ← mkCasesResultType alts
      return .cases { c with alts, resultType }
    | .return fvarId => f fvarId
    | .jmp fvarId .. =>
      unless (← read).contains fvarId do
        throwError "`Code.bind` failed, it contains a out of scope join point"
      return c
    | .unreach type =>
      /-
      Create an auxiliary parameter `aux : type` to compute the resulting type of `f aux`.
      This code is not very efficient, we could ask caller to provide the type of `c >>= f`,
      but this is more convenient, and this case is seldom reached.
      -/
      let auxParam ← mkAuxParam type
      let k ← f auxParam.fvarId
      let typeNew ← k.inferType
      eraseCode k
      eraseParam auxParam
      return .unreach typeNew

instance : MonadCodeBind CompilerM where
  codeBind := CompilerM.codeBind

instance [MonadCodeBind m] : MonadCodeBind (ReaderT ρ m) where
  codeBind c f ctx := c.bind fun fvarId => f fvarId ctx

instance [STWorld ω m] [MonadCodeBind m] : MonadCodeBind (StateRefT' ω σ m) where
  codeBind c f sref := c.bind fun fvarId => f fvarId sref

/--
Create new parameters for the given arrow type.
Example: if `type` is `Nat → Bool → Int`, the result is
an array containing two new parameters with types `Nat` and `Bool`.
-/
partial def mkNewParams (type : Expr) : CompilerM (Array Param) :=
  go type #[] #[]
where
  go (type : Expr) (xs : Array Expr) (ps : Array Param) : CompilerM (Array Param) := do
    match type with
    | .forallE _ d b _ =>
      let d := d.instantiateRev xs
      let p ← mkAuxParam d
      go b (xs.push (.fvar p.fvarId)) (ps.push p)
    | _ =>
      let type := type.instantiateRev xs
      let type' := type.headBeta
      if type' != type then
        go type' #[] ps
      else
        return ps

def isEtaExpandCandidateCore (type : Expr) (params : Array Param) : Bool :=
  let typeArity := getArrowArity type
  let valueArity := params.size
  typeArity > valueArity

abbrev FunDecl.isEtaExpandCandidate (decl : FunDecl) : Bool :=
  isEtaExpandCandidateCore decl.type decl.params

def etaExpandCore (type : Expr) (params : Array Param) (value : Code) : CompilerM (Array Param × Code) := do
  let valueType ← instantiateForall type (params.map (mkFVar ·.fvarId))
  let psNew ← mkNewParams valueType
  let params := params ++ psNew
  let xs := psNew.map fun p => .fvar p.fvarId
  let value ← value.bind fun fvarId => do
    let auxDecl ← mkAuxLetDecl (.fvar fvarId xs)
    return .let auxDecl (.return auxDecl.fvarId)
  return (params, value)

def etaExpandCore? (type : Expr) (params : Array Param) (value : Code) : CompilerM (Option (Array Param × Code)) := do
  if isEtaExpandCandidateCore type params then
    etaExpandCore type params value
  else
    return none

def FunDecl.etaExpand (decl : FunDecl) : CompilerM FunDecl := do
  let some (params, value) ← etaExpandCore? decl.type decl.params decl.value | return decl
  decl.update decl.type params value

def Decl.etaExpand (decl : Decl) : CompilerM Decl := do
  match decl.value with
  | .code code =>
    let some (params, newCode) ← etaExpandCore? decl.type decl.params code | return decl
    return { decl with params, value := .code newCode}
  | .extern .. => return decl

end Lean.Compiler.LCNF
